{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multicomplex Mathematics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard libraries\n",
    "import sys\n",
    "\n",
    "# Packages from conda\n",
    "import numpy as np\n",
    "import sympy as sy\n",
    "\n",
    "# The multicomplex library\n",
    "import multicomplex as mcx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting started\n",
    "\n",
    "The function\n",
    "$$ f(x) = x\\sin(x) $$\n",
    "has the derivative\n",
    "$$ f'(x) = x\\cos(x)+\\sin(x) $$\n",
    "\n",
    "To evaluate this derivative with multicomplex algebra, we define a function that accepts a multicomplex argument, here the lambda function ``lambda z: z*np.sin(z)``, and compare the numerical derivative with the exact result.\n",
    "\n",
    "The function ``diff_mcx1`` takes as its first argument the single-variable callable to be differentiated, followed by the real value ``x`` at which the derivatives are evaluated, and finally the number of derivatives ``numderiv``. All derivatives from order 1 to ``numderiv`` (inclusive) are returned in a list. The optional flag ``and_val`` defaults to ``False``."
   ]
  },
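  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of why this works, assuming the library follows the standard complex-step idea (the notebook itself does not spell this out): for a small real step $h$, a Taylor expansion along the imaginary direction gives\n",
    "$$ f(x+ih) = f(x) + ih\\,f'(x) - \\frac{h^2}{2}f''(x) + \\mathcal{O}(h^3) $$\n",
    "so that\n",
    "$$ f'(x) \\approx \\frac{\\operatorname{Im}\\,f(x+ih)}{h} $$\n",
    "with no subtraction of nearly equal numbers, which is why $h$ can be taken extremely small. Multicomplex numbers extend this to higher derivatives by adding one independent imaginary unit per derivative order."
   ]
  },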
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "diff_mcx1(*args, **kwargs)\n",
      "Overloaded function.\n",
      "\n",
      "1. diff_mcx1(f: Callable[[mcx::MultiComplex<double>], mcx::MultiComplex<double>], x: float, numderiv: int, and_val: bool = False) -> List[float]\n",
      "\n",
      "2. diff_mcx1(f: Callable[[mcx::MultiComplex<double>], Tuple[mcx::MultiComplex<double>, mcx::MultiComplex<double>]], x: float, numderiv: int, and_val: bool = False) -> Tuple[List[float], List[float]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(mcx.diff_mcx1.__doc__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = 0.1234\n",
    "exact = x*np.cos(x) + np.sin(x)\n",
    "nderiv = 4\n",
    "and_val = False\n",
    "mcxval = mcx.diff_mcx1(lambda z: z*np.sin(z), x, nderiv, and_val)[0]\n",
    "error = mcxval - exact\n",
    "error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Derivatives of a function of one variable\n",
    "\n",
    "Here we test some more interesting functions and show that the relative errors of the first three derivatives are all small: at most about $6\\times10^{-15}$ in magnitude, close to double-precision round-off."
   ]
  },
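  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the quantity reported by the cell below for each derivative order $N$ is the signed relative error\n",
    "$$ \\varepsilon_N = \\frac{f^{(N)}_{\\rm mcx}(x_0) - f^{(N)}_{\\rm exact}(x_0)}{f^{(N)}_{\\rm exact}(x_0)}, \\qquad x_0 = 0.1234, $$\n",
    "where the subscript labels and the symbol $x_0$ are notation introduced here for illustration: \"mcx\" is the value returned by ``diff_mcx1`` and \"exact\" is the sympy derivative evaluated with mpmath."
   ]
  },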
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([mpf('2.3990722300687498e-16'), mpf('5.7028372834229652e-15'),\n",
       "       mpf('-2.440695170029012e-15')], dtype=object)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "array([mpf('1.7975675189132607e-16'), mpf('1.247770572193306e-16'),\n",
       "       mpf('0.0')], dtype=object)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "array([mpf('2.1730817635703391e-16'), mpf('2.1651710041894285e-16'),\n",
       "       mpf('-2.1372587912639966e-16')], dtype=object)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def deriver(var, func, nderiv):\n",
    "    # Convert the sympy expression into a Python callable built on numpy\n",
    "    # functions, so that it can be called directly with multicomplex\n",
    "    # arguments\n",
    "    py_func = sy.lambdify(var, func, 'numpy')\n",
    "\n",
    "    # Calculate \"exact\" derivatives from sympy and mpmath\n",
    "    exacts = []\n",
    "    for N in range(1, nderiv+1):\n",
    "        f = sy.lambdify(var,sy.diff(func,var,N),'mpmath')\n",
    "        exacts.append(f(0.1234))\n",
    "    exacts = np.array(exacts)\n",
    "\n",
    "    # Calculate MultiComplex derivatives for derivatives of order \n",
    "    # 1 to nderiv (inclusive) in one shot\n",
    "    mcxval = mcx.diff_mcx1(py_func, 0.1234, nderiv)\n",
    "    error = np.array(mcxval - exacts)/np.array(exacts)\n",
    "    return error\n",
    "\n",
    "# Calculate some derivatives of functions to test out our implementation\n",
    "y = sy.symbols('y')\n",
    "display(deriver(y, 1.0/sy.log(y), 3))\n",
    "display(deriver(y, sy.cos(y)*sy.sin(y)*sy.exp(y), 3))\n",
    "display(deriver(y, 1/(4+sy.cos(y)*sy.sin(y)*sy.exp(y)/sy.cosh(y))-sy.log(y), 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Derivatives of a multivariate function\n",
    "\n",
    "The multivariate case is very similar. We define a function that takes a _vector_ of multicomplex arguments, and pass to ``diff_mcxN`` the real values at which the derivative is evaluated together with the number of derivatives to take with respect to each independent variable."
   ]
  },
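  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check for the example that follows, with $f(x,y,z)=\\cos(x)\\sin(y)e^{z}$ and orders $[1, 1, 2]$ the requested quantity is\n",
    "$$ \\frac{\\partial^4 f}{\\partial x\\,\\partial y\\,\\partial z^2} = \\frac{d\\cos(x)}{dx}\\cdot\\frac{d\\sin(y)}{dy}\\cdot\\frac{d^2 e^{z}}{dz^2} = -\\sin(x)\\cos(y)e^{z} $$\n",
    "because $f$ is a product of single-variable factors and $e^{z}$ is its own derivative. This is the expression coded as ``exact`` in the next cell."
   ]
  },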
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.0005830791997394103, -0.0005830791997394103)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
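   ],
   "source": [
    "def func(zs):\n",
    "    x, y, z = zs\n",
    "    return np.cos(x)*np.sin(y)*np.exp(z)\n",
    "xs = [0.1234, 20.1234, -4.1234]\n",
    "# One derivative in x, one in y, two in z\n",
    "orders = [1, 1, 2]\n",
    "def exact(zs):\n",
    "    # Analytic d^4 f/(dx dy dz^2); the two z-derivatives leave exp(z) unchanged\n",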
    "    x, y, z=zs\n",
    "    return -np.sin(x)*np.cos(y)*np.exp(z)\n",
    "\n",
    "mcx.diff_mcxN(func, xs, orders), exact(xs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Length of x: 3 does not equal that of the order: 1",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-6-223fa6a2f688>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;31m# A bad set of inputs (order too short), and an error returned\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmcx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdiff_mcxN\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m: Length of x: 3 does not equal that of the order: 1"
     ]
    }
   ],
   "source": [
    "# A bad set of inputs (order too short), and an error returned\n",
    "mcx.diff_mcxN(func, xs, [1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.0005830791997394104, -0.0005830791997394103, -0.0005830791997394103)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def func(zs):\n",
    "    x, y, z = zs\n",
    "    return np.cos(x)*np.sin(y)*np.exp(z)\n",
    "\n",
    "def der_sympy(xs):\n",
    "    \"\"\" \"Exact\" solution in higher precision mathematics with sympy \"\"\"\n",
    "    x, y, z = sy.symbols('x,y,z')\n",
    "    expr = sy.cos(x)*sy.sin(y)*sy.exp(z)\n",
    "    # Differentiate once in x, twice in z, and once in y\n",
    "    f = sy.diff(sy.diff(sy.diff(expr,x),z,2),y)\n",
    "    f_num = sy.lambdify((x,y,z),f,'mpmath')\n",
    "    return f_num(*xs)\n",
    "\n",
    "zs = []\n",
    "h = 1e-50\n",
    "# NOTE: with L = 2, index 2**L-1 = 3 presumably reads the i1*i2 coefficient\n",
    "# (d2f/dxdy). It matches the requested 4th-order derivative here only because\n",
    "# exp(z) is its own derivative; in general one would expect L = numderiv.\n",
    "L = 2\n",
    "xs = [0.1234, 20.1234, -4.1234]\n",
    "orders = [1, 1, 2]\n",
    "numderiv = sum(orders)\n",
    "\n",
    "# This block mirrors the logic that goes into diff_mcxN;\n",
    "# diff_mcx1 uses a simplified variant of it\n",
    "for i in range(3):\n",
    "    c = np.zeros((2**numderiv,))\n",
    "    c[0] = xs[i]\n",
    "    assert orders == [1, 1, 2]  # The orders are hard-coded in the block below\n",
    "    # <magic>\n",
    "    if i == 0:\n",
    "        c[1] = h\n",
    "    elif i == 1:\n",
    "        c[2] = h\n",
    "    elif i == 2:\n",
    "        c[4] = h\n",
    "        c[8] = h\n",
    "    # </magic>\n",
    "    zs.append(mcx.MultiComplex(c))\n",
    "\n",
    "# The three values are: the derivative from the hand-rolled multicomplex\n",
    "# evaluation, the derivative from symbolic math, and the result of diff_mcxN\n",
    "func(zs).get_coef()[2**L-1]/h**L, float(der_sympy(xs)), mcx.diff_mcxN(func, xs, orders)"
   ]
  },
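  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of what the hand-rolled block above computes, assuming the coefficient array is ordered so that index $2^{k-1}$ holds the coefficient of the imaginary unit $i_k$ (an assumption inferred from the indices 1, 2, 4, 8 used above, not something the notebook states): the inputs are perturbed as\n",
    "$$ x \\to x + h\\,i_1, \\qquad y \\to y + h\\,i_2, \\qquad z \\to z + h\\,(i_3 + i_4) $$\n",
    "and a Taylor expansion gives, to leading order in $h$,\n",
    "$$ \\frac{\\partial^4 f}{\\partial x\\,\\partial y\\,\\partial z^2} \\approx \\frac{[f]_{i_1 i_2 i_3 i_4}}{h^4}, \\qquad \\frac{\\partial^2 f}{\\partial x\\,\\partial y} \\approx \\frac{[f]_{i_1 i_2}}{h^2} $$\n",
    "where $[f]_{u}$ denotes the coefficient of the unit product $u$ (notation introduced here). Under the same ordering assumption, index $2^L-1=3$ with $L=2$ selects the $i_1 i_2$ coefficient; for this particular $f$ both derivatives equal $-\\sin(x)\\cos(y)e^{z}$, so the three printed values agree."
   ]
  },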
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Time Profiling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.01 µs ± 9.72 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
      "1.38 µs ± 54.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
      "14 µs ± 952 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "22.9 µs ± 334 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "x = mcx.MultiComplex([1,2])\n",
    "%timeit x*x\n",
    "x = mcx.MultiComplex([1,2,3,4])\n",
    "%timeit x*x\n",
    "nderiv = 4\n",
    "%timeit mcx.diff_mcx1(lambda x: x*np.sin(x), 0.1234, nderiv)\n",
    "%timeit mcx.diff_mcx1(lambda x: x*np.sin(x)*np.cos(x), 0.1234, nderiv)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
